package com.ihome.framework.core.mdc;

import java.lang.reflect.Method;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.Future;

import org.aopalliance.intercept.MethodInvocation;
import org.slf4j.MDC;
import org.springframework.aop.support.AopUtils;
import org.springframework.core.BridgeMethodResolver;
import org.springframework.scheduling.annotation.AnnotationAsyncExecutionInterceptor;
import org.springframework.util.ClassUtils;

/**
 * An {@link AnnotationAsyncExecutionInterceptor} that carries the SLF4J {@link MDC} context of the
 * calling thread over to the thread that executes the asynchronous method.
 *
 * <p>A copy of the caller's MDC is captured when the invocation is submitted. That copy is installed
 * on the executing thread while the target method runs. Afterwards, the executing thread's previous
 * MDC state is restored.
 */
public class MDCAwareAnnotationAsyncExecutionInterceptor extends AnnotationAsyncExecutionInterceptor {

	/**
	 * Creates the interceptor.
	 *
	 * @param defaultExecutor executor handed to the superclass constructor (presumably the fallback
	 *        when no specific executor is resolved for a method; see Spring's documentation)
	 */
	public MDCAwareAnnotationAsyncExecutionInterceptor(Executor defaultExecutor) {
		super(defaultExecutor);
	}

	/**
	 * Submits the intercepted invocation to the async executor, propagating the caller's MDC.
	 *
	 * @param invocation the intercepted method invocation
	 * @return the executor's {@link Future} when the intercepted method declares a
	 *         {@code Future}-compatible return type, otherwise {@code null}
	 * @throws Throwable anything thrown while resolving the executor or submitting the task
	 */
	@Override
	public Object invoke(final MethodInvocation invocation) throws Throwable {
		Class<?> targetClass = (invocation.getThis() != null ? AopUtils.getTargetClass(invocation.getThis()) : null);
		// Resolve the most specific, non-bridge method on the target class.
		// Executor lookup and error handling then see the user-declared method.
		Method specificMethod = ClassUtils.getMostSpecificMethod(invocation.getMethod(), targetClass);
		specificMethod = BridgeMethodResolver.findBridgedMethod(specificMethod);

		Future<?> result = determineAsyncExecutor(specificMethod).submit(buildCallable(invocation, specificMethod));

		// Only hand the Future back when the caller can actually receive one.
		if (Future.class.isAssignableFrom(invocation.getMethod().getReturnType())) {
			return result;
		} else {
			return null;
		}
	}

	/**
	 * Handles a failure of the asynchronous invocation by delegating unchanged to the superclass.
	 *
	 * @param ex the failure raised by the target method
	 * @param method the user-declared method that was invoked
	 * @param params the arguments the method was invoked with
	 * @throws Exception whatever the superclass implementation throws
	 */
	@Override
	protected void handleError(Throwable ex, Method method, Object... params) throws Exception {
		super.handleError(ex, method, params);
	}

	/**
	 * Creates the task submitted to the executor.
	 *
	 * <p>The task captures a copy of the calling thread's MDC. The method is protected, so subclasses
	 * may override it.
	 *
	 * @param invocation the intercepted method invocation
	 * @param specificMethod the resolved user-declared method, used for error handling
	 * @return the task to submit
	 */
	protected Callable<?> buildCallable(final MethodInvocation invocation, Method specificMethod) {
		return new MDCAwareCallable(invocation, MDC.getCopyOfContextMap(), specificMethod);
	}

	/**
	 * Installs {@code contextMap} as the current thread's MDC. When the map is {@code null}, the MDC
	 * is cleared instead.
	 *
	 * <p>{@link MDC#getCopyOfContextMap()} may return {@code null} when no MDC has been set.
	 * {@link MDC#setContextMap(Map)} only accepts {@code null} since SLF4J 2.0; older adapters throw
	 * {@link NullPointerException}. The null case is therefore mapped to {@link MDC#clear()}.
	 *
	 * @param contextMap the MDC contents to install, or {@code null} for an empty MDC
	 */
	private static void applyContextMap(Map<String, String> contextMap) {
		if (contextMap == null) {
			MDC.clear();
		} else {
			MDC.setContextMap(contextMap);
		}
	}

	/**
	 * Task that runs the intercepted invocation with a captured MDC installed.
	 *
	 * <p>It is a non-static inner class because it reports failures through the enclosing
	 * interceptor's {@link #handleError(Throwable, Method, Object...)}.
	 */
	private final class MDCAwareCallable implements Callable<Object> {

		private final MethodInvocation invocation;

		/** MDC snapshot taken on the submitting thread; may be {@code null} when it had no MDC. */
		private final Map<String, String> mdcContextMap;

		private final Method method;

		public MDCAwareCallable(final MethodInvocation invocation, Map<String, String> mdcContextMap,
				Method specificMethod) {
			this.invocation = invocation;
			this.mdcContextMap = mdcContextMap;
			this.method = specificMethod;
		}

		/**
		 * Runs the invocation with the captured MDC installed.
		 *
		 * @return the value of a {@link Future} returned by the target method, otherwise {@code null}
		 * @throws Exception whatever {@code handleError} rethrows
		 */
		@Override
		public Object call() throws Exception {
			// Back up the executing thread's own MDC (may be null) so it can be restored afterwards.
			Map<String, String> previousContextMap = MDC.getCopyOfContextMap();
			try {
				applyContextMap(mdcContextMap);
				Object result = invocation.proceed();
				if (result instanceof Future) {
					// Unwrap the target's Future so the executor's Future carries its value.
					return ((Future<?>) result).get();
				}
			} catch (ExecutionException ex) {
				// Report the underlying failure rather than the wrapper, so callers do not
				// see a doubly wrapped exception.
				Throwable cause = ex.getCause();
				handleError(cause != null ? cause : ex, method, invocation.getArguments());
			} catch (Throwable ex) {
				handleError(ex, method, invocation.getArguments());
			} finally {
				// Restore the executing thread's previous MDC. Use the null-safe helper so a failure
				// here cannot mask the task's result or exception.
				applyContextMap(previousContextMap);
			}
			return null;
		}
	}
}
